<?php

namespace DataCube\DataCubeAggregation\AI_Toolkit\Regression;

use DataCube\DataCubeAggregation\AI_Toolkit\Interfaces\RubixEstimator;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\NeuralNet\CostFunctions\HuberLoss;
use Rubix\ML\NeuralNet\Optimizers\Adam;
use Rubix\ML\Regressors\Adaline as RubixMLAdaline;
use Symfony\Component\OptionsResolver\OptionsResolver;

/**
 * Wraps the Rubix ML Adaline regressor behind the toolkit's RubixEstimator interface.
 *
 * Options are resolved with Symfony OptionsResolver against the defaults declared in
 * configureOptions(); the resolved values are passed positionally to the Rubix constructor.
 */
class Adaline extends BaseRegression implements RubixEstimator
{
    /**
     * The wrapped Rubix ML Adaline instance; assigned in the constructor.
     *
     * @var RubixMLAdaline|null
     */
    public $estimator = null;

    /**
     * Resolve the options and build the underlying Rubix ML Adaline regressor.
     *
     * @param array $options Overrides for the defaults declared in configureOptions().
     *                       'optimizer' may be an optimizer object, null, or an array; an
     *                       array is handed to optimizerStrConvertor() (defined outside this
     *                       class — presumably it builds an optimizer from that config; confirm).
     *
     * @throws \InvalidArgumentException If 'optimizer' is given as an empty array.
     */
    public function __construct(array $options = [])
    {
        $resolver = new OptionsResolver();
        $this->configureOptions($resolver);
        $this->options = $resolver->resolve($options);

        // Test for the array form first so that an empty array is actually rejected here,
        // rather than slipping past the conversion and reaching the Rubix constructor.
        if (is_array($this->options['optimizer'])) {
            if ($this->options['optimizer'] === []) {
                throw new \InvalidArgumentException('You must specify at least one optimizer parameter');
            }
            $this->optimizerStrConvertor();
        }

        // Rubix Adaline constructor parameters, in positional order:
        // batchSize, optimizer, l2Penalty, epochs, minChange, window, costFn.
        $this->estimator = new RubixMLAdaline(
            $this->options['batchSize'],
            $this->options['optimizer'],
            $this->options['l2Penalty'],
            $this->options['epochs'],
            $this->options['minChange'],
            $this->options['window'],
            $this->options['costFn'],
        );
    }

    /**
     * Declare the supported options and their defaults.
     *
     * Defaults: batchSize 256, Adam optimizer (rate 0.001), l2Penalty 1e-4, 500 epochs,
     * minChange 1e-6, window 5, Huber loss (alpha 2.5). The optimizer and cost-function
     * objects are created fresh on each call, so instances do not share them.
     */
    public function configureOptions(OptionsResolver $resolver): void
    {
        $resolver->setDefaults([
            'batchSize' => 256,
            'optimizer' => new Adam(0.001),
            'l2Penalty' => 1e-4,
            'epochs' => 500,
            'minChange' => 1e-6,
            'window' => 5,
            'costFn' => new HuberLoss(2.5),
        ]);
    }

    /**
     * Train the wrapped estimator on the given samples and labels.
     *
     * @param array $data   Samples; wrapped together with $labels in a Rubix Labeled dataset.
     * @param array $labels Target values, one per sample.
     * @param mixed $verify Not used by this implementation.
     */
    public function train(array $data, array $labels = [], $verify = true)
    {
        $this->estimator->train(new Labeled($data, $labels));
    }

    /**
     * Predict the target for a single sample.
     *
     * $target is wrapped in a one-element array, so it is assumed to be one sample
     * (a list of feature values) — TODO confirm with callers. Returns the wrapped
     * estimator's predict() result for that one-sample dataset.
     *
     * @param mixed $target A single sample.
     * @return mixed
     */
    public function predict($target)
    {
        return $this->estimator->predict(new Unlabeled([$target]));
    }
}